import numpy as np  # 导入numpy库，用于数组和矩阵运算
import pandas as pd  # 导入pandas库，用于数据处理和创建数据表
import matplotlib.pyplot as plt  # 导入matplotlib库，用于绘图
import time  # 导入time库，用于控制程序暂停时间

# 定义强化学习的一些超参数
ALPHA = 0.1  # 学习率，控制更新Q值的幅度
GAMMA = 0.95  # 折扣因子，控制未来奖励的重要性
EPSILION = 0.9  # epsilon值，用于ε-贪婪策略，控制探索与利用的权衡
N_STATE = 6  # 状态的数量，表示状态空间的大小
ACTIONS = ['left', 'right']  # 可能的动作列表，表示智能体的可选行为
MAX_EPISODES = 200  # 最大的训练轮次，表示最大实验次数
FRESH_TIME = 0.1  # 控制环境更新的时间间隔，用于显示训练过程

# 定义Q表的构建函数
def build_q_table(n_state, actions):
    # 创建一个Q表，行代表状态，列代表动作，初始时所有Q值为0
    q_table = pd.DataFrame(
        np.zeros((n_state, len(actions))),  # 初始化一个全0的表格，大小为状态数x动作数
        np.arange(n_state),  # 状态的索引
        actions  # 动作的名称
    )
    return q_table  # 返回初始化后的Q表

# 定义选择动作的函数
def choose_action(state, q_table):
    # epslion - greedy策略
    state_action = q_table.loc[state, :]  # 获取当前状态下所有可能动作的Q值
    if np.random.uniform() > EPSILION or (state_action == 0).all():  # 如果随机数大于epsilon或所有动作Q值为0
        action_name = np.random.choice(ACTIONS)  # 选择一个随机动作（探索）
    else:
        action_name = state_action.idxmax()  # 否则选择Q值最大的动作（利用）
    return action_name  # 返回选择的动作

# 定义环境反馈的函数
def get_env_feedback(state, action):
    # 根据当前状态和动作来返回下一个状态和奖励
    if action == 'right':  # 如果选择向右移动
        if state == N_STATE - 2:  # 如果当前状态是倒数第二个状态
            next_state = 'terminal'  # 到达终止状态
            reward = 1  # 奖励为1
        else:
            next_state = state + 1  # 否则状态加1
            reward = -0.5  # 奖励为-0.5
    else:  # 如果选择向左移动
        if state == 0:  # 如果当前状态是最左边的状态
            next_state = 0  # 保持在原地
        else:
            next_state = state - 1  # 否则状态减1
        reward = -0.5  # 奖励为-0.5
    return next_state, reward  # 返回下一个状态和奖励

# 定义环境更新的函数
def update_env(state, episode, step_counter):
    # 生成一个表示环境的字符串，'-'表示空地，'T'表示终止状态
    env = ['-'] * (N_STATE - 1) + ['T']
    if state == 'terminal':  # 如果到达终止状态
        print("Episode {}, the total step is {}".format(episode + 1, step_counter))  # 打印当前回合和步骤
        final_env = ['-'] * (N_STATE - 1) + ['T']  # 环境没有变化
        return True, step_counter  # 终止回合，返回True
    else:
        env[state] = '*'  # 将当前状态位置标记为'*'
        env = ''.join(env)  # 将环境列表转化为字符串
        print(env)  # 打印当前环境的状态
        time.sleep(FRESH_TIME)  # 暂停程序FRESH_TIME秒，模拟环境变化的延迟
        return False, step_counter  # 没有到达终止状态，返回False

# 定义SARSA学习算法的函数
def sarsa_learning():
    q_table = build_q_table(N_STATE, ACTIONS)  # 创建一个Q表
    step_counter_times = []  # 用于记录每个回合的步骤数
    for episode in range(MAX_EPISODES):  # 进行最大回合数的学习
        state = 0  # 初始状态设为0
        is_terminal = False  # 初始状态不是终止状态
        step_counter = 0  # 初始步骤计数为0
        update_env(state, episode, step_counter)  # 更新环境并显示
        while not is_terminal:  # 当未到达终止状态时继续学习
            action = choose_action(state, q_table)  # 根据当前状态选择动作
            next_state, reward = get_env_feedback(state, action)  # 获取环境反馈（下一个状态和奖励）
            if next_state != 'terminal':  # 如果不是终止状态
                next_action = choose_action(next_state, q_table)  # 选择下一个状态的动作（SARSA更新方法）
            else:
                next_action = action  # 如果是终止状态，动作不再改变
            next_q = q_table.loc[state, action]  # 获取当前Q值

            if next_state == 'terminal':  # 如果到达终止状态
                is_terminal = True  # 设置为终止状态
                q_target = reward  # 目标Q值为奖励
            else:
                delta = reward + GAMMA * q_table.loc[next_state, next_action] - q_table.loc[state, action]  # SARSA更新公式
                q_table.loc[state, action] += ALPHA * delta  # 更新Q表中的值
            state = next_state  # 更新当前状态为下一个状态
            is_terminal, steps = update_env(state, episode, step_counter + 1)  # 更新环境并检查是否终止
            step_counter += 1  # 增加步骤计数
            if is_terminal:  # 如果到达终止状态，记录步骤数
                step_counter_times.append(steps)
                
    return q_table, step_counter_times  # 返回更新后的Q表和每回合的步骤数

# 主函数入口
if __name__ == '__main__':
    q_table, step_counter_times = sarsa_learning()  # 执行SARSA学习
    print("Q table\n{}\n".format(q_table))  # 打印最终的Q表
    print('end')  # 输出训练结束
    plt.plot(step_counter_times, 'g-')  # 绘制每回合的步骤数变化曲线
    plt.ylabel("steps")  # 设置y轴标签
    plt.title("Sarsa Algorithm")  # 设置图标题
    plt.show()  # 显示图形
    print("The step_counter_times is {}".format(step_counter_times))  # 打印每回合的步骤数

